import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
import pandas as pd
from pathos.multiprocessing import Pool
from linearmodels import OLS
import datetime
from jm.library.data_helper import filter_data, date_to_str, is_between, is_weakly_less_than
from jm.library.likelihood import Likelihood
import warnings
warnings.filterwarnings("ignore")

# Global matplotlib styling applied to every figure drawn by this script.
rcParams.update({
    'figure.figsize': (13.0, 7.0),
    'lines.linewidth': 3,
    'font.size': 18,
    'text.usetex': False,
    'text.latex.preamble': r'\usepackage{amsmath}',  # for \text command
})

# Seed NumPy's global random number generator.
np.random.seed(1234)

# Raw input tables, read relative to the current working directory.
df_status = pd.read_csv('data/jesus_maria_status_deidentified.csv')
df_status_voluntary = pd.read_csv('data/jesus_maria_status_voluntary_deidentified.csv')


def get_cols_over_time(col_fmt='payments_by_{}'):
    """Return one column name per entry of ``Likelihood.event_dates``.

    Each name is ``col_fmt`` formatted with ``date_to_str(date)``; the result
    follows the order of ``Likelihood.event_dates``.
    """
    names = []
    for event_date in Likelihood.event_dates:
        names.append(col_fmt.format(date_to_str(event_date)))
    return names


def plot_over_time(df, col_fmt='payments_by_{}', func=np.sum, ax=None, linestyle='solid', normalizer=1):
    """Plot a column-wise aggregate of a per-date column family.

    The columns named by ``get_cols_over_time(col_fmt)`` are divided by
    ``normalizer``, relabelled with ``Likelihood.event_dates``, reduced
    column by column with ``func`` and drawn with the pandas ``.plot``
    method.

    Returns whatever ``.plot`` returns (passed ``ax`` and ``linestyle``).
    """
    scaled = df[get_cols_over_time(col_fmt)] / normalizer
    scaled.columns = Likelihood.event_dates
    aggregated = scaled.apply(func, axis=0)
    return aggregated.plot(ax=ax, linestyle=linestyle)


# In[5]:


# Column-name families read from the input table, one name per event date.
payment_cols = get_cols_over_time()
priority_cols = get_cols_over_time('priority_by_{}')
action_cols = get_cols_over_time('action_by_{}')
action_nomedida_cols = get_cols_over_time('action_no_medida_by_{}')
remaining_balance_cols = get_cols_over_time('remaining_balance_by_{}')

# Column-name families derived from the payment columns by prefixing.
fwdiff_payment_cols = ['fwdiff_' + name for name in payment_cols]
relative_payment_cols = ['relative_' + name for name in payment_cols]
relative_fwdiff_payment_cols = ['relative_' + name for name in fwdiff_payment_cols]
binary_fwdiff_payment_cols = ['binary_' + name for name in fwdiff_payment_cols]
cumul_binary_fwdiff_payment_cols = ['cumul_binary_' + name for name in fwdiff_payment_cols]
positive_fwdiff_payment_cols = ['positive_' + name for name in fwdiff_payment_cols]

# Payments and forward differences expressed as a share of total_due.
df_status[relative_payment_cols] = df_status[payment_cols].div(
    df_status['total_due'], axis=0)
df_status[relative_fwdiff_payment_cols] = df_status[fwdiff_payment_cols].div(
    df_status['total_due'], axis=0)

# Copy of the forward differences with every strictly positive entry set
# to 1; all other entries keep their original value.
df_status[binary_fwdiff_payment_cols] = df_status[fwdiff_payment_cols]
for col in binary_fwdiff_payment_cols:
    is_positive = df_status[col] > 0
    df_status.loc[is_positive, col] = 1

# Row-wise sum of the binary columns strictly before position j
# (the j-th binary column itself is not included).
for j, cumul_col in enumerate(cumul_binary_fwdiff_payment_cols):
    earlier_binary_cols = binary_fwdiff_payment_cols[:j]
    df_status[cumul_col] = df_status[earlier_binary_cols].sum(axis=1)

df_status[positive_fwdiff_payment_cols] = df_status[fwdiff_payment_cols].gt(0)

# Remaining balance per date: four times total_due minus payments so far.
for balance_col, paid_col in zip(remaining_balance_cols, payment_cols):
    df_status[balance_col] = 4 * df_status['total_due'] - df_status[paid_col]

# Percentile bucket (0-99) of total_due, rescaled to [0, 1).
df_status['quantile_total_due'] = pd.qcut(df_status['total_due'], 100, labels=False) / 100


# Left-hand side is the last payment column; treatment assignment is the
# only regressor in the short formula.
formula = payment_cols[-1] + ' ~ assignment_to_treatment'

due_vars_list = ['arbitrios_due', 'predial_due', 'Q1_arbitrios_due', 'Q1_predial_due']
due_vars = 'arbitrios_due + predial_due + Q1_arbitrios_due + Q1_predial_due'
Q1_due_vars = 'Q1_arbitrios_due + Q1_predial_due'

var_list = (
    ['assignment_to_treatment']
    + due_vars_list
    + [
        'score_exo_covariates', 'score_endo_covariates',
        'prob_repayment_endo_covariates',
        'is_female', 'is_male', 'is_pricos',
        'has_employer', 'has_education',
        'has_email', 'has_cellular', 'salary',
        'last_year_share_repaid_by_3', 'age',
    ]
)

extra_var_list = ['is_age_imputed', 'is_last_year_share_repaid_by_3_imputed']

longer_var_list = var_list + extra_var_list

# Coefficients reported in the regression tables.
print_vars = ['Intercept', 'G1_medida', 'G1_rec1', 'action_medida',
              'action_rec1', 'positive_payments',
              'priority_G1', 'priority_G2',
              'priority_G3', 'prob_repayment_endo_covariates', 'relative_payments',
              'last_year_share_repaid_by_3']

print_vars_no_lastyear = list(print_vars)
print_vars_no_lastyear.remove('last_year_share_repaid_by_3')

# Long formula: the short formula followed by every entry of var_list.
formula_long = (payment_cols[-1] + ' ~ assignment_to_treatment'
                + ''.join(' + ' + name for name in var_list))
formula_long_nopricos = payment_cols[-1] + ' ~ assignment_to_treatment'

# Extra controls appended to the binary-payment regressions; each entry
# keeps a leading ' + ' so the string can be concatenated to a formula.
covariates_forbinary = ''.join(
    ' + ' + name
    for name in var_list
    if name not in ('prob_repayment_endo_covariates', 'assignment_to_treatment'))

covariates_forbinary_longer = covariates_forbinary + ''.join(
    ' + ' + name for name in extra_var_list)



# In[7]:


# Split the sample by treatment assignment.
df_treatment = filter_data(df_status, assignment_to_treatment=1)
df_control = filter_data(df_status, assignment_to_treatment=0)


# In[8]:


# Upper percentiles of total_due over the full sample, computed in one call.
(td_pctile99, td_pctile98, td_pctile95,
 td_pctile90, td_pctile80, td_pctile70) = df_status['total_due'].quantile(
    [0.99, 0.98, 0.95, 0.90, 0.80, 0.70]).values

# 99th percentile of the last payment column.
payment_pctile99 = df_status[payment_cols[-1]].quantile([0.99]).values[0]


# In[9]:


def is_ever(df, cols, value, col_names=Likelihood.event_dates):
    """Flag, per row and per position i, whether ``value`` occurs in ``cols[:i+1]``.

    For every index ``i`` of ``cols`` a pool task takes the columns
    ``cols[:i+1]`` of ``df`` and evaluates ``np.isin(value, row)`` for each
    row. The per-index results are concatenated side by side, labelled with
    ``col_names`` and multiplied by 1 before being returned.

    Side effect: rebinds the module-level names ``REF_DF``, ``REF_COLS`` and
    ``REF_VALUE`` on every call.

    Note that the default for ``col_names`` is the ``Likelihood.event_dates``
    object as it was when this function was defined.
    """
    global REF_DF
    global REF_COLS
    global REF_VALUE

    # NOTE(review): the inputs are stashed in module globals and read back
    # inside the lambda - presumably so the pool workers can resolve them;
    # confirm before refactoring.
    REF_DF = df
    REF_COLS = cols
    REF_VALUE = value

    with Pool() as pool:
        # One task per column position i.
        list_is = pool.map(
            lambda i: REF_DF.loc[:, REF_COLS[:i+1]].apply(lambda r: np.isin(REF_VALUE, r), axis=1, raw=True),
            list(range(len(REF_COLS))))
    this_df = pd.concat(list_is, axis=1)
    this_df.columns = col_names
    return 1 * this_df

# # Binary Payments

# ## Summary Statistics

# ### Number of payment events in treatment v.s. control


# Write the sum of the last cumulative binary-payment column for each arm.
with open("figs/tableOB1_actual.txt", "w") as text_file:
    num_events_control = sum(df_control[cumul_binary_fwdiff_payment_cols[-1]])
    num_events_treatment = sum(df_treatment[cumul_binary_fwdiff_payment_cols[-1]])
    report_lines = [
        "Num Events in Control: " + str(num_events_control),
        "Num Events in Treatment: " + str(num_events_treatment),
    ]
    for report_line in report_lines:
        text_file.write(report_line + "\n")


# ### Average Payment per Event

# In[24]:

# Average payment per event = total of the last payment column divided by
# the total of the last cumulative binary-payment column, per arm.
with open("figs/tableOB2_actual.txt", "w") as text_file:
    last_payment_col = payment_cols[-1]
    last_events_col = cumul_binary_fwdiff_payment_cols[-1]

    avg_payment_treatment = (
        df_treatment[last_payment_col].sum() / df_treatment[last_events_col].sum())
    print("Average Payment per Event in Treatment" + ": " + str(avg_payment_treatment),
          file=text_file)

    avg_payment_control = (
        df_control[last_payment_col].sum() / df_control[last_events_col].sum())
    print("Average Payment per Event in Control" + ": " + str(avg_payment_control),
          file=text_file)

    # Same statistic restricted to rows with total_due <= td_pctile99.
    df_treatment_td99 = df_treatment[df_treatment['total_due'].le(td_pctile99)]
    df_control_td99 = df_control[df_control['total_due'].le(td_pctile99)]

    avg_payment_treatment_td99 = (
        df_treatment_td99[last_payment_col].sum() / df_treatment_td99[last_events_col].sum())
    print("Average Payment per Event in Treatment (bottom 99th percentile total due)" + ": "
          + str(avg_payment_treatment_td99), file=text_file)

    avg_payment_control_td99 = (
        df_control_td99[last_payment_col].sum() / df_control_td99[last_events_col].sum())
    print("Average Payment per Event in Control (bottom 99th percentile total due)" + ": "
          + str(avg_payment_control_td99), file=text_file)


# Indicators derived from the `sex` and `is_person` columns.
df_status['is_male'] = df_status['sex'].eq('M')
df_status['is_female'] = df_status['sex'].eq('F')
df_status['is_person'] = df_status['is_person'].astype(bool)
df_status['is_sex_missing'] = df_status['sex'].isna()
df_status.reset_index(inplace=True)

# Replace missing last-year repayment shares with the sample mean, which is
# also stored as a constant column.
var = 'last_year_share_repaid_by_3'
avg_var = 'avg_' + var
df_status[avg_var] = df_status[var].mean()
share_is_missing = df_status[var].isna()
df_status.loc[share_is_missing, var] = df_status[avg_var]

# Where a repayment probability is missing, fill it with
# score / (score + Q1_total_due) using the matching score column.
for prob_col, score_col in (
        ('prob_repayment_endo_covariates', 'score_endo_covariates'),
        ('prob_repayment_exo_covariates', 'score_exo_covariates')):
    prob_is_missing = np.isnan(df_status[prob_col])
    df_status.loc[prob_is_missing, prob_col] = (
        df_status[score_col] / (df_status[score_col] + df_status['Q1_total_due']))

# Binary Payments: reshaping dataset, generating variables

# unpivot the date columns
# Columns carried along unchanged (as id_vars) when the per-date columns
# are melted into long format.
vars_to_keep = [
    'id_scrambled', 'assignment_to_treatment',
    'prob_repayment_endo_covariates', 'has_education',
    'is_person', 'is_male', 'is_female', 'total_due', 'is_pricos',
    'has_email', 'has_employer',
    'prob_repayment_exo_covariates',
    'last_year_share_repaid_by_3',
    'score_endo_covariates', 'score_exo_covariates',
    'Q1_arbitrios_due', 'Q1_predial_due', 'has_cellular',
    'arbitrios_due', 'predial_due', 'salary',
    'age', 'quantile_total_due', 'days_from_G1_to_promise',
] + extra_var_list


def generate_melted_df(value_name, cols_list):
    """Melt a per-date column family of ``df_status`` into long format.

    Parameters
    ----------
    value_name : str
        Name given to the melted value column. The text
        ``value_name + '_by_'`` is also removed from the original column
        names to obtain the ``date`` column.
    cols_list : list of str
        Wide columns to unpivot.

    Returns
    -------
    pandas.DataFrame
        The ``vars_to_keep`` columns, a ``date`` column and a ``value_name``
        column, with one row per original row and melted column.
    """
    df = pd.melt(df_status[vars_to_keep + cols_list],
                 id_vars=vars_to_keep,
                 var_name='date',
                 value_name=value_name)
    # Strip the prefix as a literal string. The default of `regex` differs
    # between pandas versions, so it is passed explicitly.
    df['date'] = df['date'].str.replace(value_name + '_by_', '', regex=False)
    return df


def gen_priority_indicator(df, priority):
    """Add a float 0/1 column ``'priority_' + priority`` to ``df`` in place.

    The new column is 1.0 where ``df['priority']`` equals ``priority`` and
    0.0 elsewhere. The same ``df`` object is returned.
    """
    matches = df['priority'] == priority
    df['priority_' + priority] = matches.astype(float)
    return df


def gen_action_indicator(df, action):
    """Add a float 0/1 column ``'action_' + action`` to ``df`` in place.

    The new column is 1.0 where ``df['action']`` equals ``action`` and 0.0
    elsewhere. The same ``df`` object is returned.
    """
    matches = df['action'] == action
    df['action_' + action] = matches.astype(float)
    return df


def gen_interaction_priority(df, var, priority):
    """Add the column ``var + '_' + priority`` to ``df`` in place.

    The new column holds ``df[var]`` (as float) on rows where
    ``df['priority']`` equals ``priority`` and that value times zero on all
    other rows. The same ``df`` object is returned.
    """
    as_float = 1.0 * df[var]
    in_priority = df['priority'] == priority
    df[var + "_" + priority] = as_float * in_priority
    return df


# In[33]:


# Long-format copy of each per-date column family.
binary_fwdiff_payments_melted = generate_melted_df(
    value_name='binary_fwdiff_payments', cols_list=binary_fwdiff_payment_cols)
payments_melted = generate_melted_df(
    value_name='relative_payments', cols_list=relative_payment_cols)
actions_melted = generate_melted_df(
    value_name='action', cols_list=action_cols)
priority_melted = generate_melted_df(
    value_name='priority', cols_list=priority_cols)

# Join the long tables on the carried-along columns plus the date,
# in the order: actions, priority, payments.
merge_keys = vars_to_keep + ['date']
df_binary_reg = binary_fwdiff_payments_melted
for melted_family in (actions_melted, priority_melted, payments_melted):
    df_binary_reg = df_binary_reg.merge(melted_family, on=merge_keys)

df_binary_reg['date'] = pd.to_datetime(df_binary_reg['date'])

# Week counter: 1 on the reference date, plus elapsed days divided by 7.
week_reference_date = datetime.datetime.strptime('2021-04-12', '%Y-%m-%d')
days_since_reference = (df_binary_reg['date'] - week_reference_date).dt.days
df_binary_reg['week'] = 1 + days_since_reference/7


# In[34]:


# Complement of the is_person flag.
df_binary_reg['is_firm'] = 1 - df_binary_reg['is_person']

# One 0/1 column per priority level and per action type.
for priority in ['G1', 'G2', 'G3', 'N']:
    df_binary_reg = gen_priority_indicator(df_binary_reg, priority)

for action in ['valor', 'rec1', 'medida', 'N']:
    df_binary_reg = gen_action_indicator(df_binary_reg, action)

# Interactions of the G1 priority indicator with three action indicators.
for g1_action in ('valor', 'rec1', 'medida'):
    df_binary_reg['G1_' + g1_action] = (
        df_binary_reg['priority_G1'] * df_binary_reg['action_' + g1_action])

# 1.0 where the relative payment is strictly positive, else 0.0.
has_paid = df_binary_reg['relative_payments'] > 0
df_binary_reg['positive_payments'] = 1.0 * has_paid

# ## Binary Payments: Reduced form regressions

# Benchmark specification shared by all regressions below.
benchmark_formula = ('binary_fwdiff_payments ~ relative_payments + positive_payments + priority_G1 + priority_G2 + '
                     'priority_G3 + action_rec1 + action_medida + prob_repayment_endo_covariates + G1_rec1 + G1_medida')


# Generate Table 5 (column 1): benchmark specification
formula = benchmark_formula
model = OLS.from_formula(formula, data=df_binary_reg)
res = model.fit(cov_type='robust')

df_results = res.params[print_vars_no_lastyear].to_frame('Parameter 1')
df_results['Standard Error 1'] = res.std_errors[print_vars_no_lastyear]
# Add an empty row for the regressor that only appears in column 2, so the
# column-2 results below align on the same index.
# `np.nan` is used because the `np.NaN` alias was removed in NumPy 2.0.
df_results.loc['last_year_share_repaid_by_3', :] = np.nan

# Generate Table 5 (column 2): benchmark + last year's repayment share,
# estimated on rows where that share is not NaN
formula = benchmark_formula + ' + last_year_share_repaid_by_3'
df_binary_reg_temp = df_binary_reg[~np.isnan(df_binary_reg['last_year_share_repaid_by_3'])]
model = OLS.from_formula(formula, data=df_binary_reg_temp)
res = model.fit(cov_type='robust')

df_results['Parameter 2'] = res.params[print_vars]
df_results['Standard Error 2'] = res.std_errors[print_vars]

# Print Table 5
with open("figs/table5.txt", "w") as text_file:
    text_file.write(df_results.to_latex(float_format="%.3f"))

# Generate Table OA1: benchmark + covariates_forbinary controls

formula = benchmark_formula + covariates_forbinary
model = OLS.from_formula(formula, data=df_binary_reg)
res = model.fit(cov_type='robust')

df_results_with_controls = res.params[print_vars].to_frame('Parameter with controls')
df_results_with_controls['Standard Error 2'] = res.std_errors[print_vars]

# Print Table OA1
with open("figs/tableOA1.txt", "w") as text_file:
    text_file.write(df_results_with_controls.to_latex(float_format="%.3f"))


# Generate Table OA1 expanded: controls plus the entries of extra_var_list

formula = benchmark_formula + covariates_forbinary_longer
model = OLS.from_formula(formula, data=df_binary_reg)
res = model.fit(cov_type='robust')

df_results_with_expcontrols = res.params[print_vars].to_frame('Parameter with controls')
df_results_with_expcontrols['Standard Error 2'] = res.std_errors[print_vars]

# Print Table OA1 expanded
with open("figs/tableOA1_missing.txt", "w") as text_file:
    text_file.write(df_results_with_expcontrols.to_latex(float_format="%.3f"))



# Measuring medida in control by score/total due
# Decile bins of total_due in the control group: lower and upper edges.
quantile_list_lower = [0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9]
quantile_list_upper = quantile_list_lower[1:] + [1]
total_due_list_lower = df_control['total_due'].quantile(quantile_list_lower)
total_due_list_upper = df_control['total_due'].quantile(quantile_list_upper)

# Quintile bins of total_due in the control group: lower and upper edges.
quintile_list_lower = [0,0.2, 0.4, 0.6, 0.8]
quintile_list_upper = quintile_list_lower[1:] + [1]
total_due_quintile_list_lower = df_control['total_due'].quantile(quintile_list_lower)
total_due_quintile_list_upper = df_control['total_due'].quantile(quintile_list_upper)


# In[100]:


def first_time_event(df, varname, value, cols, ne=False):
    """Record in ``df[varname]`` the first event date at which a condition holds.

    With ``ne=False`` the condition for position ``j`` is
    ``df[cols[j]] == value``; with ``ne=True`` it is ``df[cols[j]] != value``.
    ``df[varname]`` is (re)initialised to NaN and then filled, in place, with
    ``Likelihood.event_dates[j]`` for the first ``j`` at which the condition
    holds. Rows where it never holds stay NaN. Nothing is returned.

    ``cols`` is indexed with the positions of ``Likelihood.event_dates``.
    """
    def _condition_holds(col):
        # Per-row test for one column, in the direction selected by `ne`.
        if ne:
            return df[col] != value
        return df[col] == value

    event_dates = Likelihood.event_dates
    df[varname] = np.nan
    df.loc[_condition_holds(cols[0]), varname] = event_dates[0]
    for j in range(1, len(event_dates)):
        # Condition switches on between j-1 and j, and no earlier date has
        # been recorded for the row.
        switches_on = (~_condition_holds(cols[j-1])
                       & _condition_holds(cols[j])
                       & df[varname].isnull())
        df.loc[switches_on, varname] = event_dates[j]


# In[101]:


def get_time_to_medida_from_rec1(df, n=8):
    """Add "within n weeks of the first rec1" flags to ``df`` in place.

    Columns written:

    * ``first_time_rec1``, ``first_time_medida``: first event date at which
      an action column equals ``"rec1"`` / ``"medida"``.
    * ``first_time_payment``: first event date at which a payment column
      differs from 0.
    * ``time_to_medida_from_rec1_le<n>weeks``: whether the first medida is at
      most ``7 * n`` days after the first rec1; NaN where there is no rec1
      date.
    * ``time_to_payment_from_rec1_le<n>weeks``: the same comparison for the
      first payment.

    Rows with a medida date but no rec1 date get the medida date as their
    rec1 date. A missing date on either side of a difference makes the
    comparison False.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame holding the ``action_cols`` and ``payment_cols`` columns.
    n : int
        Window length in weeks.
    """
    first_time_event(df, 'first_time_rec1', "rec1", action_cols)
    first_time_event(df, 'first_time_medida', "medida", action_cols)
    first_time_event(df, 'first_time_payment', 0, payment_cols, ne=True)
    df.loc[
        (~df['first_time_medida'].isna()) &
        (df['first_time_rec1'].isna()), 'first_time_rec1'] = df['first_time_medida']

    max_days = 7*n
    rec1_dates = pd.to_datetime(df['first_time_rec1'])

    # `.dt.days` replaces `.astype('timedelta64[D]')`: pandas >= 2.0 rejects
    # day-resolution timedelta dtypes, while `.dt.days` gives the whole-day
    # count on both old and new versions (NaN where a date is missing).
    days_to_medida = (pd.to_datetime(df['first_time_medida']) - rec1_dates).dt.days
    df['time_to_medida_from_rec1_le' + str(n) + 'weeks'] = days_to_medida <= max_days
    df.loc[df['first_time_rec1'].isna(), 'time_to_medida_from_rec1_le' + str(n) + 'weeks'] = np.nan

    days_to_payment = (pd.to_datetime(df['first_time_payment']) - rec1_dates).dt.days
    df['time_to_payment_from_rec1_le' + str(n) + 'weeks'] = days_to_payment <= max_days


# In[102]:


# Add the "within n weeks" flag columns for n = 2..9 to both arms
# (control first, then treatment, for each n).
for n in range(2, 10):
    for arm_df in (df_control, df_treatment):
        get_time_to_medida_from_rec1(arm_df, n=n)


# In[103]:


def implemented_threats(df, date_index, medida_within_n_weeks=False, 
                        paid_within_n_weeks=False, n=8, 
                        total_due_list_lower=total_due_list_lower,
                        total_due_list_upper=total_due_list_upper):
    """Count unpaid rec1 rows and medida rows within each total_due bin.

    Side effect: writes the boolean column ``df['is_rec1_by']``, True where
    ``df[action_cols[date_index]]`` is ``'rec1'`` or ``'medida'``.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to analyse (modified in place, see above).
    date_index : int
        Position in ``action_cols`` used to build ``is_rec1_by``.
    medida_within_n_weeks : bool
        If True, a row counts as medida only when its
        ``time_to_medida_from_rec1_le<n>weeks`` column is also True.
    paid_within_n_weeks : bool
        If True, "unpaid" is read from ``time_to_payment_from_rec1_le<n>weeks``
        instead of the last payment column.
    n : int
        Window length in weeks used in the two column names above.
    total_due_list_lower, total_due_list_upper : iterable of numbers
        Paired lower/upper bin edges; bins include both edges.

    Returns
    -------
    tuple of three lists, one entry per bin
        ``(unpaid, unpaid_plus_medida, ratio)`` where ``unpaid`` counts rows
        flagged ``is_rec1_by``, not counted as medida, whose payment column
        equals 0; ``unpaid_plus_medida`` adds the medida count; and
        ``ratio = unpaid / unpaid_plus_medida`` (NaN if the denominator is 0).
    """
    if paid_within_n_weeks:
        this_payment_col = 'time_to_payment_from_rec1_le' + str(n) + 'weeks'
    else:
        this_payment_col = payment_cols[-1]
    df['is_rec1_by'] = df[action_cols[date_index]].isin(['rec1','medida'])

    list_num = []
    list_num_plus_medida = []
    list_ratio_unpaid_or_medida = []
    for w1, w2 in zip(total_due_list_lower, total_due_list_upper):
        df_slice = df[df['total_due'].between(w1, w2)]

        received_medida = df_slice[action_cols[-1]] == 'medida'
        if medida_within_n_weeks:
            # The flag column may hold NaN, so compare against True
            # explicitly instead of relying on truthiness.
            received_medida = received_medida & (
                df_slice['time_to_medida_from_rec1_le' + str(n) + 'weeks'] == True)

        unpaid = df_slice[this_payment_col] == 0
        num_medida = int(received_medida.sum())
        num_rec1_neverpaid_notmedida = int(
            (df_slice['is_rec1_by'] & ~received_medida & unpaid).sum())

        denominator = num_rec1_neverpaid_notmedida + num_medida
        # An empty bin would otherwise divide 0 by 0.
        if denominator:
            ratio = num_rec1_neverpaid_notmedida / denominator
        else:
            ratio = float('nan')

        list_ratio_unpaid_or_medida.append(ratio)
        list_num.append(num_rec1_neverpaid_notmedida)
        list_num_plus_medida.append(denominator)
    return list_num, list_num_plus_medida, list_ratio_unpaid_or_medida

def generate_output_implemented_threats(date_index, medida_within_n_weeks=False, 
                                        paid_within_n_weeks=False, n=8, 
                                        total_due_list_lower=total_due_list_lower,
                                        total_due_list_upper=total_due_list_upper, file=None):
    """Print the per-bin ``implemented_threats`` results for both arms.

    Runs ``implemented_threats`` on ``df_control`` and ``df_treatment`` with
    the given options and prints three groups of lines: the unpaid counts,
    the unpaid-plus-medida counts, and one minus the returned ratio rounded
    to four decimals.

    Parameters
    ----------
    date_index, medida_within_n_weeks, paid_within_n_weeks, n,
    total_due_list_lower, total_due_list_upper
        Forwarded unchanged to ``implemented_threats``.
    file : file-like or None
        Destination passed to ``print``; ``None`` prints to standard output.
    """
    shared_options = dict(
        medida_within_n_weeks=medida_within_n_weeks,
        paid_within_n_weeks=paid_within_n_weeks,
        n=n,
        total_due_list_lower=total_due_list_lower,
        total_due_list_upper=total_due_list_upper)
    control_num_urec1, control_num_urec1_plusmedida, control_prob_no_medida_expanded = implemented_threats(
        df_control, date_index, **shared_options)
    treatment_num_urec1, treatment_num_urec1_plusmedida, treatment_prob_no_medida_expanded = implemented_threats(
        df_treatment, date_index, **shared_options)

    # Convert NumPy scalars to builtin numbers before printing: under
    # NumPy >= 2 a list of NumPy scalars prints as `np.int64(3)` /
    # `np.float64(0.5)` rather than plain numbers.
    def _counts(values):
        return [int(v) for v in values]

    def _complement_rounded(values):
        return [float(np.round(1-x, 4)) for x in values]

    print("Numer of unpaid writs and no garnishments for rec1s issued by week", date_index, file=file)
    print("Control:", _counts(control_num_urec1), file=file)
    print("Treatment:", _counts(treatment_num_urec1), file=file)
    print("Numer of unpaid writs or garnshiments for rec1s issued by week", date_index, 
          (" --- unpaid within " + str(n) + " weeks") * paid_within_n_weeks, 
          (" --- no medida within " + str(n) + " weeks") * medida_within_n_weeks, file=file)
    print("Control:", _counts(control_num_urec1_plusmedida), file=file)
    print("Treatment:", _counts(treatment_num_urec1_plusmedida), file=file)
    print("Ratio of writs issued by week", date_index, 
          "with garnishment to the sum of the numerator plus the number of unpaid writs issued by week", 
          date_index, 
          (" --- unpaid within " + str(n) + " weeks") * paid_within_n_weeks, 
          (" --- no medida within " + str(n) +" weeks") * medida_within_n_weeks, file=file)
    print("Control:", _complement_rounded(control_prob_no_medida_expanded), file=file)
    print("Treatment:", _complement_rounded(treatment_prob_no_medida_expanded), file=file)

# Table 2
# Table 2 uses the quintile bins and date index 14.
table2_bin_edges = dict(
    total_due_list_lower=total_due_quintile_list_lower,
    total_due_list_upper=total_due_quintile_list_upper)
with open("figs/table2.txt", "w") as text_file:
    generate_output_implemented_threats(14, file=text_file, **table2_bin_edges)


# # Relative payment by treatment

# Edges of the total_due bins; consecutive pairs form one bin each.
list_tax_due = [99, 1000, 5000, 1e10]

# Strictly positive relative forward-difference payments per bin, keyed by
# the bin's lower edge, for each arm.
map_payments_treatment = {}
map_payments_control = {}
with open("figs/tableOB10.txt", "w") as text_file:
    # All treatment bins are written first, then all control bins.
    for arm_label, arm_df, arm_map in (
            ("Treatment", df_treatment, map_payments_treatment),
            ("Control", df_control, map_payments_control)):
        for l, u in zip(list_tax_due[:-1], list_tax_due[1:]):
            in_bin = filter_data(arm_df, total_due=is_between(l, u))
            payments = in_bin[relative_fwdiff_payment_cols].to_numpy().flatten()
            positive_payments = payments[payments > 0]
            arm_map[l] = positive_payments
            print(arm_label + " (" + str(l) + "," + str(u) + "): " + str(arm_map[l].mean()),
                  file=text_file)
